{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Using Pòlya-Gamma Auxiliary Variables for Binary Classification\n",
    "\n",
    "## Overview\n",
    "\n",
    "In this notebook, we'll demonstrate how to use Pòlya-Gamma auxiliary variables to do efficient inference for Gaussian Process binary classification as in reference [1]. \n",
    "We will also use natural gradient descent, as described in more detail in the [Natural gradient descent](./Natural_Gradient_Descent.ipynb) tutorial.\n",
    "\n",
    "\n",
    "[1] Florian Wenzel, Theo Galy-Fajou, Christian Donner, Marius Kloft, Manfred Opper. [Efficient Gaussian process classification using Pòlya-Gamma data augmentation](https://arxiv.org/abs/1802.06383). Proceedings of the AAAI Conference on Artificial Intelligence. 2019.\n",
    "\n",
    "## Pòlya-Gamma Augmentation\n",
    "\n",
    "When a Gaussian Process prior is paired with a Gaussian likelihood, inference can be done exactly with a simple closed-form expression.\n",
    "Unfortunately, this attractive feature does not carry over to non-conjugate likelihoods such as the Bernoulli likelihood that arises in binary classification with a logistic link function.\n",
    "Sampling-based stochastic variational inference offers a general strategy for dealing with non-conjugate likelihoods; see the [corresponding tutorial](./Non_Gaussian_Likelihoods.ipynb).\n",
    "\n",
    "Another possible strategy is to introduce additional latent variables that restore conjugacy. \n",
    "This is the strategy we follow here. \n",
    "In particular, we introduce one Pòlya-Gamma auxiliary variable for each data point in our training dataset. \n",
    "The [Pòlya-Gamma](https://arxiv.org/abs/1205.0310) distribution $\\rm{PG}$ is a univariate distribution with support on the positive real line. \n",
    "In our context it is useful because, if $\\omega_i$ is distributed according to $\\rm{PG}(1,0)$, then the logistic likelihood $\\sigma(\\cdot)$ for data point $(x_i, y_i)$ can be represented as\n",
    "\n",
    "\\begin{align}\n",
    "\\sigma(y_i f_i) = \\frac{1}{1 + \\exp(-y_i f_i)} = \\tfrac{1}{2} \\mathbb{E}_{\\omega_i \\sim \\rm{PG}(1,0)} \\left[ \\exp \\left(\\tfrac{1}{2} y_i f_i - \\tfrac{\\omega_i}{2} f_i^2 \\right) \\right]\n",
    "\\end{align}\n",
    "\n",
    "where $y_i \\in \\{-1, 1\\}$ is the binary label of data point $i$\n",
    "and $f_i = f(x_i)$ is the value at input $x_i$ of the latent function $f$, which has a Gaussian Process prior. \n",
    "The crucial point here is that $f_i$ appears at most quadratically in the exponent inside the expectation. \n",
    "In other words, conditioned on $\\omega_i$, the likelihood is Gaussian in form as a function of $f_i$, so we can integrate out $f_i$ exactly, just as if we were doing regression with a Gaussian likelihood. For more details, see the original reference [1]. \n",
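    "\n",
    "As a quick numerical sanity check of this identity, the next cell draws approximate samples $\\omega \\sim \\rm{PG}(1,0)$ and compares a Monte Carlo estimate of the right-hand side with $\\sigma(z)$ for a few values of $z = y_i f_i$. It is a minimal, standalone sketch rather than anything GPyTorch provides: it assumes the sum-of-gammas representation of the Pòlya-Gamma distribution from the paper linked above, $\\omega = \\frac{1}{2\\pi^2} \\sum_{k \\ge 1} \\frac{g_k}{(k - 1/2)^2}$ with $g_k \\sim \\mathrm{Exp}(1)$, truncated after a finite number of terms, and the helper names `sample_pg10` and `pg_sigmoid_estimate` are hypothetical. Because of the truncation and the Monte Carlo noise, expect agreement to roughly two decimal places rather than exact equality."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import math\n",
    "import random\n",
    "\n",
    "def sample_pg10(rng, num_terms=100):\n",
    "    # Approximate draw from PG(1, 0) using the truncated series\n",
    "    # omega = 1 / (2 pi^2) * sum_{k>=1} g_k / (k - 1/2)^2 with g_k ~ Exp(1)\n",
    "    total = 0.0\n",
    "    for k in range(1, num_terms + 1):\n",
    "        total += rng.expovariate(1.0) / (k - 0.5) ** 2\n",
    "    return total / (2.0 * math.pi ** 2)\n",
    "\n",
    "def sigmoid(z):\n",
    "    return 1.0 / (1.0 + math.exp(-z))\n",
    "\n",
    "def pg_sigmoid_estimate(z, num_samples=5000, seed=0):\n",
    "    # Monte Carlo estimate of 1/2 * E[exp(z / 2 - omega * z^2 / 2)] with omega ~ PG(1, 0)\n",
    "    rng = random.Random(seed)\n",
    "    acc = 0.0\n",
    "    for _ in range(num_samples):\n",
    "        omega = sample_pg10(rng)\n",
    "        acc += math.exp(0.5 * z - 0.5 * omega * z ** 2)\n",
    "    return 0.5 * acc / num_samples\n",
    "\n",
    "for z in (-2.0, 0.0, 1.5):\n",
    "    print(f'z = {z:+.1f}: sigmoid(z) = {sigmoid(z):.4f}, PG estimate = {pg_sigmoid_estimate(z):.4f}')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [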
    "\n",
    "## Setup"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import tqdm\n",
    "import math\n",
    "import torch\n",
    "import gpytorch\n",
    "from matplotlib import pyplot as plt\n",
    "\n",
    "# Make plots inline\n",
    "%matplotlib inline"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "For this example notebook, we'll create a simple artificial dataset."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "from math import floor\n",
    "\n",
    "# this is for running the notebook in our testing framework\n",
    "smoke_test = ('CI' in os.environ)\n",
    "\n",
    "N = 100\n",
    "X = torch.linspace(-1., 1., N)\n",
    "probs = (torch.sin(X * math.pi).add(1.).div(2.))\n",
    "y = torch.distributions.Bernoulli(probs=probs).sample()\n",
    "X = X.unsqueeze(-1)\n",
    "\n",
    "train_n = int(floor(0.8 * N))\n",
    "indices = torch.randperm(N)\n",
    "train_x = X[indices[:train_n]].contiguous()\n",
    "train_y = y[indices[:train_n]].contiguous()\n",
    "\n",
    "test_x = X[indices[train_n:]].contiguous()\n",
    "test_y = y[indices[train_n:]].contiguous()\n",
    "\n",
    "if torch.cuda.is_available():\n",
    "    train_x, train_y, test_x, test_y = train_x.cuda(), train_y.cuda(), test_x.cuda(), test_y.cuda()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Let's plot our artificial dataset. \n",
    "Note that here the binary labels are 0/1-valued; we need to be careful to translate between this representation and the -1/1 representation that is most natural in the context of Pòlya-Gamma augmentation."
   ]
  },
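  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The two label conventions describe the same likelihood: with $\\tilde{y}_i = 2 y_i - 1 \\in \\{-1, 1\\}$, the Bernoulli probability $p(y_i \\mid f_i)$ with a logistic link equals $\\sigma(\\tilde{y}_i f_i)$, because $1 - \\sigma(f) = \\sigma(-f)$. The next cell is a small standalone check of this fact in plain Python; the function names are hypothetical and chosen only for this illustration, while the mapping $2y - 1$ is the same one the likelihood defined below applies to its targets."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import math\n",
    "\n",
    "def to_pm1(y01):\n",
    "    # Map a 0/1 label to the -1/+1 convention used in the PG identity\n",
    "    return 2 * y01 - 1\n",
    "\n",
    "def bernoulli_prob(y01, f):\n",
    "    # Bernoulli likelihood with a logistic link, written for 0/1 labels\n",
    "    p1 = 1.0 / (1.0 + math.exp(-f))\n",
    "    return p1 if y01 == 1 else 1.0 - p1\n",
    "\n",
    "def logistic_pm1(y_pm, f):\n",
    "    # The same likelihood written as sigma(y * f) for -1/+1 labels\n",
    "    return 1.0 / (1.0 + math.exp(-y_pm * f))\n",
    "\n",
    "for y01 in (0, 1):\n",
    "    for f in (-1.3, 0.0, 2.0):\n",
    "        print(y01, to_pm1(y01), f, bernoulli_prob(y01, f), logistic_pm1(to_pm1(y01), f))"
   ]
  },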
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[<matplotlib.lines.Line2D at 0x7f7ee8dd4d90>]"
      ]
     },
     "execution_count": 3,
     "metadata": {},
     "output_type": "execute_result"
    },
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAXQAAAD4CAYAAAD8Zh1EAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4yLjIsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy+WH4yJAAATwklEQVR4nO3df4wcZ33H8c+n54t0KTRO8AGJ7eBUCga3kALbmEILSSn4R0vNj1ayQUBSJMsSRvBHLYyqUqSoAhpRkSqhxk0toKrif+KGNDJ1aYEiEYF8zi/HBAcTfvh8aXIBEgqxGtt8+8eO0/F6Z/bZ29m7vYf3Szrd7swzz3znmcefzO3OZh0RAgAsfr+y0AUAAJpBoANAJgh0AMgEgQ4AmSDQASATSxZqx8uWLYtVq1Yt1O4BYFE6dOjQExEx2W3dggX6qlWrNDU1tVC7B4BFyfYPqtbxkgsAZIJAB4BMEOgAkAkCHQAyQaADQCZ63uVie4+kP5L0eET8Zpf1lnSTpI2SnpZ0XUTc03ShknTHvSd044GjmnnypC5bOqFrXzKpr3x79tnnO9at1ltesfycdksvHFeE9NTJU+e0QT4650VT57jc70UT47KlJ58+dc6cKi+vmoNzqalq+5S5XbfvlGOqOr6U425yPKvapI5HSj6k1J5yLur6mc88cq//26Lt10n6maTPVwT6RknvVzvQ10q6KSLW9tpxq9WKfm5bvOPeE/rwvsM6eepMZZuJ8TG9/VXLdfuhE5XtJsbH9LG3vYxQz0S3edHEOU6Zb91UzcF+aqo6ppS5XbdvSXM6phS9jm+u41m3v17jkbpNXe39nItu/fQ67rnMVduHIqLVdV3K/z7X9ipJd1UE+mckfTUibiueH5V0TUQ8Wtdnv4H+2o9/WSeePNmz3ZitMz2OafnSCX195+8n7xujq2peDHqOU+dbN1VzMLWmqn2nzO26fUua8zGlqDu+QcazSsp4pG5TVXu/56Kzn5Tj7neu1gV6Ex8sWi7peOn5dLHsvEC3vVXSVkm6/PLL+9rJTOJkSDnBqX1h9FWdy0HP8SDbV83B1D6r2qXM7UH3PYi6fQxj//2Ged02/c6j1H5SjrvJsWniTVF3Wdb1aCNid0S0IqI1Odn1k6uVLiuuMHoZc7dy5tYXRl/VuRz0HA+yfdUcTO2zql3K3K7b97DnfV3/w9h3ynikbtPvPErtJ+W4mxybJgJ9WtLK0vMVkmYa6PccO9at1sT4WG2bifExbVm7srbdxPiYdqxb3XR5WCDd5kUT5zhlvnVTNQf7qanqmFLmdt2+53pMKXodX9P7ThmP1G3qau/nXHTrp9dxN51HTbzkcqek7bb3qv2m6FO9Xj+fi7NvGqTc5dJ60SXc5fJLotu8aOIcd/bb790e5TnYb011x5Qyt3vteyHucpnLeKbc5VI3Hin50Kv21HNR1U/n9qNwl8ttkq6RtEzSY5L+StK4JEXEruK2xZslrVf7tsXrI6Lnu539vikKABjwTdGI2NJjfUh63xxrAwA0hE+KAkAmCHQAyASBDgCZINABIBMEOgBkgkAHgEwQ6ACQCQIdADJBoANAJgh0AMgEgQ4AmSDQASATBDoAZIJAB4BMEOgAkAkCHQAyQaADQCYIdADIBIEOAJkg0AEgEwQ6AGSCQAeATBDoAJAJAh0AMkGgA0AmCHQAyASBDgCZINABIBMEOgBkgkAHgEwQ6ACQiaRAt73e9lHbx2zv7LL+Itv/avt+20dsX998qQCAOj0D3faYpFskbZC0RtIW22s6mr1P0rci4ipJ10j6pO0LGq4VAFAj5Qr9aknHIuKRiHhG0l5JmzrahKTn2rak50j6saTTjVYKAKiVEujLJR0vPZ8ulpXdLOmlkmYkHZb0gYj4RWdHtrfanrI9NTs7O8eSAQDdpAS6uyyLjufrJN0n6TJJvyXpZtu/dt5GEbsjohURrcnJ
yb6LBQBUSwn0aUkrS89XqH0lXna9pH3RdkzS9yS9pJkSAQApUgL9oKQrbV9RvNG5WdKdHW1+KOkNkmT7BZJWS3qkyUIBAPWW9GoQEadtb5d0QNKYpD0RccT2tmL9Lkk3SPqs7cNqv0TzoYh4Yoh1AwA69Ax0SYqI/ZL2dyzbVXo8I+lNzZYGAOgHnxQFgEwQ6ACQCQIdADJBoANAJgh0AMgEgQ4AmSDQASATBDoAZIJAB4BMEOgAkAkCHQAyQaADQCYIdADIBIEOAJkg0AEgEwQ6AGSCQAeATBDoAJAJAh0AMkGgA0AmCHQAyASBDgCZINABIBMEOgBkgkAHgEwQ6ACQCQIdADJBoANAJgh0AMgEgQ4AmUgKdNvrbR+1fcz2zoo219i+z/YR2//VbJkAgF6W9Gpge0zSLZLeKGla0kHbd0bEt0ptlkr6tKT1EfFD288fVsEAgO5SrtCvlnQsIh6JiGck7ZW0qaPNOyTti4gfSlJEPN5smQCAXlICfbmk46Xn08WyshdLutj2V20fsv3ubh3Z3mp7yvbU7Ozs3CoGAHSVEujusiw6ni+R9CpJfyhpnaS/tP3i8zaK2B0RrYhoTU5O9l0sAKBaz9fQ1b4iX1l6vkLSTJc2T0TEzyX93PbXJF0l6eFGqgQA9JRyhX5Q0pW2r7B9gaTNku7saPMFSb9ne4ntCyWtlfRQs6UCAOr0vEKPiNO2t0s6IGlM0p6IOGJ7W7F+V0Q8ZPvfJD0g6ReSbo2IB4dZOADgXI7ofDl8frRarZiamlqQfQPAYmX7UES0uq3jk6IAkAkCHQAyQaADQCYIdADIBIEOAJkg0AEgEwQ6AGSCQAeATBDoAJAJAh0AMkGgA0AmCHQAyASBDgCZINABIBMEOgBkgkAHgEwQ6ACQCQIdADJBoANAJgh0AMgEgQ4AmSDQASATBDoAZIJAB4BMEOgAkAkCHQAyQaADQCYIdADIBIEOAJkg0AEgEwQ6AGQiKdBtr7d91PYx2ztr2v227TO2/6S5EgEAKXoGuu0xSbdI2iBpjaQtttdUtPuEpANNFwkA6C3lCv1qScci4pGIeEbSXkmburR7v6TbJT3eYH0AgEQpgb5c0vHS8+li2bNsL5f0Vkm76jqyvdX2lO2p2dnZfmsFANRICXR3WRYdzz8l6UMRcaauo4jYHRGtiGhNTk6m1ggASLAkoc20pJWl5yskzXS0aUnaa1uSlknaaPt0RNzRSJUAgJ5SAv2gpCttXyHphKTNkt5RbhARV5x9bPuzku4izAFgfvUM9Ig4bXu72nevjEnaExFHbG8r1te+bg4AmB8pV+iKiP2S9ncs6xrkEXHd4GUBAPrFJ0UBIBMEOgBkgkAHgEwQ6ACQCQIdADJBoANAJgh0AMgEgQ4AmSDQASATBDoAZIJAB4BMEOgAkAkCHQAyQaADQCYIdADIBIEOAJkg0AEgEwQ6AGSCQAeATBDoAJAJAh0AMkGgA0AmCHQAyASBDgCZINABIBMEOgBkgkAHgEwQ6ACQCQIdADJBoANAJpIC3fZ620dtH7O9s8v6d9p+oPi52/ZVzZcKAKjTM9Btj0m6RdIGSWskbbG9pqPZ9yS9PiJeLukGSbubLhQAUC/lCv1qScci4pGIeEbSXkmbyg0i4u6I+Enx9BuSVjRbJgCgl5RAXy7peOn5dLGsynslfbHbCttbbU/ZnpqdnU2vEgDQU0qgu8uy6NrQvlbtQP9Qt/URsTsiWhHRmpycTK8SANDTkoQ205JWlp6vkDTT2cj2yyXdKmlDRPyomfIAAKlSrtAPSrrS9hW2L5C0WdKd5Qa2L5e0T9K7IuLh5ssEAPTS8wo9Ik7b3i7pgKQxSXsi4ojtbcX6XZI+Iul5kj5tW5JOR0RreGUDADo5ouvL4UPXarViampqQfYNAIuV7UNVF8x8UhQAMkGgA0AmCHQAyASBDgCZINABIBMEOgBkgkAHgEwQ6ACQCQIdADJBoANAJgh0AMgEgQ4AmSDQ
ASATBDoAZIJAB4BMEOgAkAkCHQAyQaADQCYIdADIBIEOAJkg0AEgEwQ6AGSCQAeATBDoAJAJAh0AMkGgA0AmCHQAyASBDgCZINABIBMEOgBkgkAHgEwsSWlke72kmySNSbo1Ij7esd7F+o2SnpZ0XUTc03Ctte6494RuPHBUM0+e1GVLJ7Rj3Wq95RXLk9pfNDEuW3ry6VPnPK7rp7z90gvHFSE9dfLc7cvLy31V7buqn9Q6qo6jah9Vjzv3fe1LJvWVb8+et4+qY0rpN3V/qTX2qiNlPDr3nTKPmpyb/c7hJvedu0HyoWp+zfe5TOGIqG9gj0l6WNIbJU1LOihpS0R8q9Rmo6T3qx3oayXdFBFr6/pttVoxNTU1WPWFO+49oQ/vO6yTp848u2xifEwfe9vLKge7s32Vbv30s31nX29/1XLdfujEnLZtqo4mDXJMi6GOunmUInVu9juHm9x37prIh6r5NV/nssz2oYhodVuX8pLL1ZKORcQjEfGMpL2SNnW02STp89H2DUlLbV86UNV9uPHA0fP+EZ88dUY3Hjia3L5Kt3762b6zr9u+eXzO2zZVR5MGOabFUEfdPEqROjf7ncNN7jt3TeRD1fyar3OZKiXQl0s6Xno+XSzrt41sb7U9ZXtqdna231orzTx5spHlqf33u33ZmR5/Ec1XHU0a5JiaNKw6Bhnn1DnY1Fwddp+LUVP5UDW/5uNcpkoJdHdZ1nlkKW0UEbsjohURrcnJyZT6kly2dKKR5an997t92Zi7DdX819GkQY6pScOqY5BxTp2DTc3VYfe5GDWVD1Xzaz7OZaqUQJ+WtLL0fIWkmTm0GZod61ZrYnzsnGUT42PasW51cvsq3frpZ/vOvrasXTnnbZuqo0mDHNNiqKNuHqVInZv9zuEm9527JvKhan7N17lMlXKXy0FJV9q+QtIJSZslvaOjzZ2Sttveq/abok9FxKONVlrj7BsNqe8qd7bv9y6Xzu37vcul9aJLGrnLJfU45uMul/IxLeRdLlV1LNRdLqlzs9853OS+czdoPlTNr/k8l6l63uUiPXsXy6fUvm1xT0T8te1tkhQRu4rbFm+WtF7t2xavj4jaW1iavMsFAH5Z1N3lknQfekTsl7S/Y9mu0uOQ9L5BigQADIZPigJAJgh0AMgEgQ4AmSDQASATSXe5DGXH9qykH/SxyTJJTwypnEGNam2jWpc0urVRV/9GtbZRrUsarLYXRUTXT2YuWKD3y/ZU1a06C21UaxvVuqTRrY26+jeqtY1qXdLwauMlFwDIBIEOAJlYTIG+e6ELqDGqtY1qXdLo1kZd/RvV2ka1LmlItS2a19ABAPUW0xU6AKAGgQ4AmRipQLf9p7aP2P6F7cpbemyvt33U9jHbO0vLL7H9JdvfKX5f3FBdPfu1vdr2faWfn9r+YLHuo7ZPlNZtbKKu1NqKdt+3fbjY/1S/2w+jLtsrbX/F9kPFef9AaV2jY1Y1Z0rrbfvvivUP2H5l6raDSqjtnUVND9i+2/ZVpXVdz+s81XWN7adK5+gjqdvOQ207SnU9aPuM7UuKdcMcsz22H7f9YMX64c6ziBiZH0kvlbRa0lcltSrajEn6rqRfl3SBpPslrSnW/Y2kncXjnZI+0VBdffVb1Pjfan8AQJI+KunPhzRmSbVJ+r6kZYMeW5N1SbpU0iuLx89V+8vIz57Lxsasbs6U2myU9EW1v33r1ZK+mbrtPNT2GkkXF483nK2t7rzOU13XSLprLtsOu7aO9m+W9OVhj1nR9+skvVLSgxXrhzrPRuoKPSIeiohe36Ra96XVmyR9rnj8OUlvaai0fvt9g6TvRkQ/n4Sdq0GPecHGLCIejYh7isf/I+khdfku2gYM8kXnKdsOtbaIuDsiflI8/Yba3wg2bIMc94KPWYctkm5rcP+VIuJrkn5c02So82ykAj1R3RdSvyCKb0oqfj+/oX322+9mnT+Bthd/Yu1p6mWNPmsLSf9u+5DtrXPYflh1
SZJsr5L0CknfLC1uaswG+aLzpC9AH3JtZe9V+wrvrKrzOl91/Y7t+21/0fZv9LntsGuT7QvV/uKd20uLhzVmKYY6z5K+4KJJtv9D0gu7rPqLiPhCShddlg1872VdXX32c4GkP5b04dLiv5d0g9p13iDpk5L+bJ5re21EzNh+vqQv2f52cTUxZw2O2XPU/gf3wYj4abF4oDHr3EWXZalfdD6U+Zaw3/Mb2teqHei/W1rc+Hnto6571H5Z8WfFexx3SLoycdth13bWmyV9PSLKV83DGrMUQ51n8x7oEfEHA3ZR94XUj9m+NCIeLf6MebyJumz30+8GSfdExGOlvp99bPsfJN2VWldTtUXETPH7cdv/ovafeF/TAo+Z7XG1w/yfI2Jfqe+BxqzDIF90fkHCtoNI+oJ12y+XdKukDRHxo7PLa87r0Osq/cdXEbHf9qdtL0vZdti1lZz31/IQxyzFUOfZYnzJ5dkvrS6uhjer/SXVKn6/p3j8HkkpV/wp+un3vNfrikA7662Sur4DPqzabP+q7eeefSzpTaUaFmzMbFvSP0p6KCL+tmNdk2NWN2fK9b67uAvh1fr/LzpP2XYQPfu3fbmkfZLeFREPl5bXndf5qOuFxTmU7avVzpMfpWw77NqKmi6S9HqV5t6QxyzFcOfZMN7pneuP2v9wpyX9r6THJB0oll8maX+p3Ua174j4rtov1Zxd/jxJ/ynpO8XvSxqqq2u/Xeq6UO0JfVHH9v8k6bCkB4qTdGmDY9azNrXfOb+/+DkyKmOm9ksHUYzLfcXPxmGMWbc5I2mbpG3FY0u6pVh/WKW7rKrmW4PnsFdtt0r6SWmMpnqd13mqa3ux3/vVfrP2NaMyZsXz6yTt7dhu2GN2m6RHJZ1SO8veO5/zjI/+A0AmFuNLLgCALgh0AMgEgQ4AmSDQASATBDoAZIJAB4BMEOgAkIn/A4GXDNquEm1LAAAAAElFTkSuQmCC\n",
      "text/plain": [
       "<Figure size 432x288 with 1 Axes>"
      ]
     },
     "metadata": {
      "needs_background": "light"
     },
     "output_type": "display_data"
    }
   ],
   "source": [
    "plt.plot(train_x.squeeze(-1).cpu(), train_y.cpu(), 'o')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The following steps create the dataloader objects. See the [SVGP regression notebook](./SVGP_Regression_CUDA.ipynb) for details."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "from torch.utils.data import TensorDataset, DataLoader\n",
    "\n",
    "train_dataset = TensorDataset(train_x, train_y)\n",
    "train_loader = DataLoader(train_dataset, batch_size=100000, shuffle=False)\n",
    "\n",
    "test_dataset = TensorDataset(test_x, test_y)\n",
    "test_loader = DataLoader(test_dataset, batch_size=1024, shuffle=False)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Variational Inference with PG Auxiliaries\n",
    "\n",
    "We define a Bernoulli likelihood that leverages Pòlya-Gamma augmentation. \n",
    "It turns out that we can derive closed-form updates for the Pòlya-Gamma auxiliary variables. To make inference with the Gaussian Process scalable, we introduce inducing points at (learnable) inducing locations. \n",
    "In particular, we will need to learn a variational mean vector and a variational covariance matrix that define a Gaussian variational distribution over the function values at the inducing points. (See the discussion in the [SVGP tutorial](Approximate_GP_Objective_Functions.ipynb) for more details.) \n",
    "We will use natural gradient updates for these two variational parameters; this allows us to take large steps and typically yields fast convergence. "
   ]
  },
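  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "For reference, Reference [1] shows that the optimal variational factor for each auxiliary variable is $q(\\omega_i) = \\rm{PG}(1, c_i)$ with $c_i = \\sqrt{\\mathbb{E}[f_i^2]}$, whose mean has the closed form $\\mathbb{E}[\\omega_i] = \\tanh(c_i/2) / (2 c_i)$, with limit $1/4$ as $c_i \\to 0$. The next cell is a standalone plain-Python sketch (the helper names are hypothetical) that compares this closed form with a truncated version of the series $\\mathbb{E}[\\omega_i] = \\frac{1}{2\\pi^2} \\sum_{k \\ge 1} \\left[(k - 1/2)^2 + c_i^2/(4\\pi^2)\\right]^{-1}$ implied by the sum-of-gammas representation of $\\rm{PG}(1, c_i)$. It also evaluates the per-point term $\\tfrac{1}{2} y_i \\mathbb{E}[f_i] - \\tfrac{1}{2}\\mathbb{E}[\\omega_i]\\,\\mathbb{E}[f_i^2]$, which is the quantity the likelihood class below sums over data points, with $c_i$ treated as a constant."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import math\n",
    "\n",
    "def pg_mean_closed_form(c):\n",
    "    # E[omega] for omega ~ PG(1, c); the limit as c -> 0 is 1/4\n",
    "    if c == 0.0:\n",
    "        return 0.25\n",
    "    return math.tanh(0.5 * c) / (2.0 * c)\n",
    "\n",
    "def pg_mean_series(c, num_terms=100_000):\n",
    "    # Truncated series for the same mean:\n",
    "    # E[omega] = 1 / (2 pi^2) * sum_{k>=1} 1 / ((k - 1/2)^2 + c^2 / (4 pi^2))\n",
    "    shift = c * c / (4.0 * math.pi ** 2)\n",
    "    total = sum(1.0 / ((k - 0.5) ** 2 + shift) for k in range(1, num_terms + 1))\n",
    "    return total / (2.0 * math.pi ** 2)\n",
    "\n",
    "def expected_log_lik_term(y_pm, mean, variance):\n",
    "    # Per-point term 0.5 * y * E[f] - 0.5 * E[omega] * E[f^2], with c = sqrt(E[f^2])\n",
    "    second_moment = variance + mean ** 2\n",
    "    c = math.sqrt(second_moment)\n",
    "    return 0.5 * y_pm * mean - 0.5 * pg_mean_closed_form(c) * second_moment\n",
    "\n",
    "for c in (0.0, 0.5, 3.0):\n",
    "    print(c, pg_mean_closed_form(c), pg_mean_series(c))\n",
    "print(expected_log_lik_term(1, 0.0, 1.0))"
   ]
  },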
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "class PGLikelihood(gpytorch.likelihoods._OneDimensionalLikelihood):\n",
    "    # this method effectively computes the expected log likelihood \n",
    "    # contribution to Eqn (10) in Reference [1].\n",
    "    def expected_log_prob(self, target, input, *args, **kwargs):\n",
    "        mean, variance = input.mean, input.variance\n",
    "        # Compute the expectation E[f_i^2]\n",
    "        raw_second_moment = variance + mean.pow(2)\n",
    "\n",
    "        # Translate targets to be -1, 1\n",
    "        target = target.to(mean.dtype).mul(2.).sub(1.)\n",
    "\n",
    "        # We detach the following variable since we do not want\n",
    "        # to differentiate through the closed-form PG update.\n",
    "        c = raw_second_moment.detach().sqrt()\n",
    "        # Compute half the mean of the PG auxiliary variable: 0.5 * E[omega] = 0.25 * tanh(c / 2) / c\n",
    "        # See Eqn (11) and Appendix A2 and A3 in Reference [1] for details.\n",
    "        half_omega = 0.25 * torch.tanh(0.5 * c) / c\n",
    "\n",
    "        # Expected log likelihood\n",
    "        res = 0.5 * target * mean - half_omega * raw_second_moment\n",
    "        # Sum over data points in mini-batch\n",
    "        res = res.sum(dim=-1)\n",
    "\n",
    "        return res\n",
    "    \n",
    "    # define the likelihood\n",
    "    def forward(self, function_samples):\n",
    "        return torch.distributions.Bernoulli(logits=function_samples)\n",
    "    \n",
    "    # define the marginal likelihood using Gauss-Hermite quadrature\n",
    "    def marginal(self, function_dist):\n",
    "        prob_lambda = lambda function_samples: self.forward(function_samples).probs\n",
    "        probs = self.quadrature(prob_lambda, function_dist)\n",
    "        return torch.distributions.Bernoulli(probs=probs)\n",
    "    \n",
    "\n",
    "# define the actual GP model (kernels, inducing points, etc.)  \n",
    "class GPModel(gpytorch.models.ApproximateGP):\n",
    "    def __init__(self, inducing_points):\n",
    "        variational_distribution = gpytorch.variational.NaturalVariationalDistribution(inducing_points.size(0))\n",
    "        variational_strategy = gpytorch.variational.VariationalStrategy(\n",
    "            self, inducing_points, variational_distribution, learn_inducing_locations=True\n",
    "        )\n",
    "        super(GPModel, self).__init__(variational_strategy)\n",
    "        self.mean_module = gpytorch.means.ZeroMean()\n",
    "        self.covar_module = gpytorch.kernels.ScaleKernel(gpytorch.kernels.RBFKernel())\n",
    "\n",
    "    def forward(self, x):\n",
    "        mean_x = self.mean_module(x)\n",
    "        covar_x = self.covar_module(x)\n",
    "        return gpytorch.distributions.MultivariateNormal(mean_x, covar_x)\n",
    "\n",
    "# we initialize our model with M = 30 inducing points\n",
    "M = 30\n",
    "inducing_points = torch.linspace(-2., 2., M, dtype=train_x.dtype, device=train_x.device).unsqueeze(-1)\n",
    "model = GPModel(inducing_points=inducing_points)\n",
    "model.covar_module.base_kernel.initialize(lengthscale=0.2)\n",
    "likelihood = PGLikelihood()\n",
    "\n",
    "if torch.cuda.is_available():\n",
    "    model = model.cuda()\n",
    "    likelihood = likelihood.cuda()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Setup optimizers\n",
    "\n",
    "We will use an `NGD` (natural gradient descent) optimizer for the variational mean vector and covariance matrix of the inducing points, and the `Adam` optimizer for all other parameters (the kernel hyperparameters as well as the inducing point locations). \n",
    "Note that we use a relatively large learning rate for the `NGD` optimizer."
   ]
  },
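  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To see why the natural gradient step can be so large, consider a deliberately simplified, hypothetical toy problem: a single latent value $u$ with prior $\\mathcal{N}(0, 1)$, every $f_i = u$, and the expectations $\\mathbb{E}[\\omega_i]$ held fixed. The expected log-likelihood terms are then linear in the sufficient statistics $u$ and $u^2$, the optimal Gaussian $q^*(u)$ is available in closed form, and the natural gradient of the ELBO with respect to the natural parameters $\\theta$ of $q(u)$ is simply $\\theta^* - \\theta$. A step with learning rate 1 therefore lands exactly on $q^*$, and a smaller learning rate moves a fixed fraction of the way there. The next cell illustrates this in plain Python; it is not GPyTorch's `NGD` implementation, and in the full model the $\\mathbb{E}[\\omega_i]$ change as $q$ changes, so a single step is not exact there."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def optimal_natural_params(prior_var, ys, omega_means):\n",
    "    # Optimal Gaussian q*(u) for a prior N(0, prior_var) and fixed E[omega_i]:\n",
    "    # log q*(u) = -u^2 / (2 prior_var) + sum_i (0.5 y_i u - 0.5 E[omega_i] u^2) + const\n",
    "    precision = 1.0 / prior_var + sum(omega_means)\n",
    "    theta1 = 0.5 * sum(ys)     # coefficient of u\n",
    "    theta2 = -0.5 * precision  # coefficient of u^2\n",
    "    return (theta1, theta2)\n",
    "\n",
    "def ngd_step(theta, theta_star, lr):\n",
    "    # The natural gradient in natural-parameter space is (theta_star - theta) here,\n",
    "    # so each step moves a fraction lr of the way to the optimum\n",
    "    return tuple(t + lr * (ts - t) for t, ts in zip(theta, theta_star))\n",
    "\n",
    "def to_mean_var(theta):\n",
    "    # Convert natural parameters (theta1, theta2) back to (mean, variance)\n",
    "    theta1, theta2 = theta\n",
    "    var = -0.5 / theta2\n",
    "    return theta1 * var, var\n",
    "\n",
    "theta_star = optimal_natural_params(1.0, [1, 1, -1], [0.2, 0.25, 0.1])\n",
    "theta = (0.0, -0.5)  # start from q(u) = N(0, 1)\n",
    "print('lr = 1.0, one step: ', to_mean_var(ngd_step(theta, theta_star, 1.0)))\n",
    "for _ in range(100):\n",
    "    theta = ngd_step(theta, theta_star, 0.1)\n",
    "print('lr = 0.1, 100 steps:', to_mean_var(theta))\n",
    "print('closed-form optimum:', to_mean_var(theta_star))"
   ]
  },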
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "variational_ngd_optimizer = gpytorch.optim.NGD(model.variational_parameters(), num_data=train_y.size(0), lr=0.1)\n",
    "\n",
    "hyperparameter_optimizer = torch.optim.Adam([\n",
    "    {'params': model.hyperparameters()},\n",
    "    {'params': likelihood.parameters()},\n",
    "], lr=0.01)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Define training loop"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "488a3a3d5f704b0f98611a14e3e6143a",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(FloatProgress(value=0.0, description='Epoch', style=ProgressStyle(description_width='initial'))…"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(FloatProgress(value=0.0, description='Minibatch', max=1.0, style=ProgressStyle(description_widt…"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(FloatProgress(value=0.0, description='Minibatch', max=1.0, style=ProgressStyle(description_widt…"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(FloatProgress(value=0.0, description='Minibatch', max=1.0, style=ProgressStyle(description_widt…"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(FloatProgress(value=0.0, description='Minibatch', max=1.0, style=ProgressStyle(description_widt…"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(FloatProgress(value=0.0, description='Minibatch', max=1.0, style=ProgressStyle(description_widt…"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(FloatProgress(value=0.0, description='Minibatch', max=1.0, style=ProgressStyle(description_widt…"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(FloatProgress(value=0.0, description='Minibatch', max=1.0, style=ProgressStyle(description_widt…"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(FloatProgress(value=0.0, description='Minibatch', max=1.0, style=ProgressStyle(description_widt…"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(FloatProgress(value=0.0, description='Minibatch', max=1.0, style=ProgressStyle(description_widt…"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(FloatProgress(value=0.0, description='Minibatch', max=1.0, style=ProgressStyle(description_widt…"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(FloatProgress(value=0.0, description='Minibatch', max=1.0, style=ProgressStyle(description_widt…"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(FloatProgress(value=0.0, description='Minibatch', max=1.0, style=ProgressStyle(description_widt…"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(FloatProgress(value=0.0, description='Minibatch', max=1.0, style=ProgressStyle(description_widt…"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(FloatProgress(value=0.0, description='Minibatch', max=1.0, style=ProgressStyle(description_widt…"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(FloatProgress(value=0.0, description='Minibatch', max=1.0, style=ProgressStyle(description_widt…"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(FloatProgress(value=0.0, description='Minibatch', max=1.0, style=ProgressStyle(description_widt…"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(FloatProgress(value=0.0, description='Minibatch', max=1.0, style=ProgressStyle(description_widt…"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(FloatProgress(value=0.0, description='Minibatch', max=1.0, style=ProgressStyle(description_widt…"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(FloatProgress(value=0.0, description='Minibatch', max=1.0, style=ProgressStyle(description_widt…"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(FloatProgress(value=0.0, description='Minibatch', max=1.0, style=ProgressStyle(description_widt…"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(FloatProgress(value=0.0, description='Minibatch', max=1.0, style=ProgressStyle(description_widt…"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(FloatProgress(value=0.0, description='Minibatch', max=1.0, style=ProgressStyle(description_widt…"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(FloatProgress(value=0.0, description='Minibatch', max=1.0, style=ProgressStyle(description_widt…"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(FloatProgress(value=0.0, description='Minibatch', max=1.0, style=ProgressStyle(description_widt…"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(FloatProgress(value=0.0, description='Minibatch', max=1.0, style=ProgressStyle(description_widt…"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(FloatProgress(value=0.0, description='Minibatch', max=1.0, style=ProgressStyle(description_widt…"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(FloatProgress(value=0.0, description='Minibatch', max=1.0, style=ProgressStyle(description_widt…"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(FloatProgress(value=0.0, description='Minibatch', max=1.0, style=ProgressStyle(description_widt…"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(FloatProgress(value=0.0, description='Minibatch', max=1.0, style=ProgressStyle(description_widt…"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(FloatProgress(value=0.0, description='Minibatch', max=1.0, style=ProgressStyle(description_widt…"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(FloatProgress(value=0.0, description='Minibatch', max=1.0, style=ProgressStyle(description_widt…"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(FloatProgress(value=0.0, description='Minibatch', max=1.0, style=ProgressStyle(description_widt…"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(FloatProgress(value=0.0, description='Minibatch', max=1.0, style=ProgressStyle(description_widt…"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(FloatProgress(value=0.0, description='Minibatch', max=1.0, style=ProgressStyle(description_widt…"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(FloatProgress(value=0.0, description='Minibatch', max=1.0, style=ProgressStyle(description_widt…"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(FloatProgress(value=0.0, description='Minibatch', max=1.0, style=ProgressStyle(description_widt…"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(FloatProgress(value=0.0, description='Minibatch', max=1.0, style=ProgressStyle(description_widt…"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(FloatProgress(value=0.0, description='Minibatch', max=1.0, style=ProgressStyle(description_widt…"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(FloatProgress(value=0.0, description='Minibatch', max=1.0, style=ProgressStyle(description_widt…"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(FloatProgress(value=0.0, description='Minibatch', max=1.0, style=ProgressStyle(description_widt…"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(FloatProgress(value=0.0, description='Minibatch', max=1.0, style=ProgressStyle(description_widt…"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(FloatProgress(value=0.0, description='Minibatch', max=1.0, style=ProgressStyle(description_widt…"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(FloatProgress(value=0.0, description='Minibatch', max=1.0, style=ProgressStyle(description_widt…"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(FloatProgress(value=0.0, description='Minibatch', max=1.0, style=ProgressStyle(description_widt…"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(FloatProgress(value=0.0, description='Minibatch', max=1.0, style=ProgressStyle(description_widt…"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(FloatProgress(value=0.0, description='Minibatch', max=1.0, style=ProgressStyle(description_widt…"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(FloatProgress(value=0.0, description='Minibatch', max=1.0, style=ProgressStyle(description_widt…"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(FloatProgress(value=0.0, description='Minibatch', max=1.0, style=ProgressStyle(description_widt…"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(FloatProgress(value=0.0, description='Minibatch', max=1.0, style=ProgressStyle(description_widt…"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(FloatProgress(value=0.0, description='Minibatch', max=1.0, style=ProgressStyle(description_widt…"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(FloatProgress(value=0.0, description='Minibatch', max=1.0, style=ProgressStyle(description_widt…"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(FloatProgress(value=0.0, description='Minibatch', max=1.0, style=ProgressStyle(description_widt…"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(FloatProgress(value=0.0, description='Minibatch', max=1.0, style=ProgressStyle(description_widt…"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(FloatProgress(value=0.0, description='Minibatch', max=1.0, style=ProgressStyle(description_widt…"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(FloatProgress(value=0.0, description='Minibatch', max=1.0, style=ProgressStyle(description_widt…"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(FloatProgress(value=0.0, description='Minibatch', max=1.0, style=ProgressStyle(description_widt…"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(FloatProgress(value=0.0, description='Minibatch', max=1.0, style=ProgressStyle(description_widt…"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(FloatProgress(value=0.0, description='Minibatch', max=1.0, style=ProgressStyle(description_widt…"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(FloatProgress(value=0.0, description='Minibatch', max=1.0, style=ProgressStyle(description_widt…"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(FloatProgress(value=0.0, description='Minibatch', max=1.0, style=ProgressStyle(description_widt…"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n"
     ]
    }
   ],
   "source": [
    "model.train()\n",
    "likelihood.train()\n",
    "mll = gpytorch.mlls.VariationalELBO(likelihood, model, num_data=train_y.size(0))\n",
    "\n",
    "num_epochs = 1 if smoke_test else 100\n",
    "epochs_iter = tqdm.notebook.tqdm(range(num_epochs), desc=\"Epoch\")\n",
    "for i in epochs_iter:\n",
    "    minibatch_iter = tqdm.notebook.tqdm(train_loader, desc=\"Minibatch\", leave=False)\n",
    "    \n",
    "    for x_batch, y_batch in minibatch_iter:\n",
    "        # Zero the gradients held by both optimizers\n",
    "        variational_ngd_optimizer.zero_grad()\n",
    "        hyperparameter_optimizer.zero_grad()\n",
    "        \n",
    "        # Compute the negative ELBO on this minibatch and backpropagate\n",
    "        output = model(x_batch)\n",
    "        loss = -mll(output, y_batch)\n",
    "        minibatch_iter.set_postfix(loss=loss.item())\n",
    "        loss.backward()\n",
    "        \n",
    "        # NGD step on the variational parameters, then a step on the hyperparameters\n",
    "        variational_ngd_optimizer.step()\n",
    "        hyperparameter_optimizer.step()"
   ]
  },
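  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The loss minimized in the training loop is the negative of `gpytorch.mlls.VariationalELBO`. As we understand GPyTorch's default behavior (`combine_terms=True`, `beta=1`), the returned value is a per-data-point estimate of the ELBO computed from a minibatch $B$. The following is a sketch of that objective, not a derivation taken from reference [1]:\n",
    "\n",
    "$$\n",
    "-\\text{loss} = \\frac{1}{|B|} \\sum_{i \\in B} \\ell_i - \\frac{1}{N}\\,\\mathrm{KL}\\left[q(\\mathbf{u}) \\,\\|\\, p(\\mathbf{u})\\right],\n",
    "$$\n",
    "\n",
    "where $N$ is `num_data` (here `train_y.size(0)`), $\\ell_i$ is the expected log-likelihood term that `likelihood` supplies for data point $i$, and $q(\\mathbf{u})$ and $p(\\mathbf{u})$ are the variational distribution and the prior over the inducing function values $\\mathbf{u}$. Averaged over random minibatches, the first term equals $\\frac{1}{N}\\sum_{i=1}^{N} \\ell_i$, so the estimate is unbiased for the full-data ELBO divided by $N$, and the loss magnitude does not depend on the batch size."
   ]
  },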
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Visualization and Evaluation"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[<matplotlib.lines.Line2D at 0x7f7ee968da30>]"
      ]
     },
     "execution_count": 8,
     "metadata": {},
     "output_type": "execute_result"
    },
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAXIAAAD8CAYAAABq6S8VAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4yLjIsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy+WH4yJAAAfhklEQVR4nO3dfZAU5Z0H8O+PZYyrMaLHqkA0rtGgghshcxoNLmooUe9Q0QgoeqHiFaEihx4nqEXqLlWJlYp4RPcwUS+mYnQjSkSjiRYGMW42EM4FdQURQ9wkCr4QI5rETdyX5/54pplne6d7eqafnumX76dqi6GnX55+md888zy/flqUUiAiouQaUe8CEBFROAzkREQJx0BORJRwDORERAnHQE5ElHAM5ERECRc6kIvIkSLytIhsF5FtInKNjYIREVEwEjaPXETGABijlNoiIgcB2AzgIqXUSzYKSERE/kLXyJVSbyilthRe/xnAdgDjwq6XiIiCGWlzZSJyNIBJADb5zTd69Gh19NFH29w0EVHqbd68+Y9KqSb3dGuBXEQ+CuAhANcqpd4v8f58APMB4KijjkJXV5etTRMRZYKI/L7UdCtZKyKSgw7i7UqpNaXmUUrdpZTKK6XyTU3DvlCIiKhKNrJWBMDdALYrpVaELxIREVXCRo38cwCuBHC2iDxf+DvfwnqJiCiA0G3kSqlOAGKhLEREVAXe2UlElHAM5ETkr/NWoKdj6OueDv0aGPqa6oKBnIiGM4P3uMnA6nnAhpXAuz3Aqrn6b9xkPc/qefo11Q0DORFpXsF71xZgymLgya8Cfb3F+Xt+qee59AdAc2sdCkwOBnIi0pzg3dOhA7MTvN9+CehcAbTMArofAE5doP86bgbyV/kHcfPLwcGmGOsYyImyyh1kneD9o9nA+puGBu9jpwE71wGtS4FNd+i/1qVA193DA7XJ/HIA2BQTEQZyoqwqFWQ7VwAnzNC1bSd4t8wGuh/UQb75jOLyzWfoZhVzHW7NrcV57rlAt62bTTGsnVvBQE6UJWYt3Amyq+bqILt6ng7W7uB92InAOd/QQf7Fh4A57fpv15biOnZt8d5mc6tugul5BhjsK05/9Jpip6mDgb0qDOREWeKuhQM6uPY8o2vgnSt0YDaD97jJwOkL9fRDm3Vgbm4Fplyrlzdfl9LToZtgWpcCI3I6eK+/Cdi2Zvh8bHapSugHS1Qjn88rjn5IVCOdt+rgaDZnrJoLjJ0E7H5OTzt1AbBxJXDWMh20HT0dhawVn0DtxwnOTnNKTwfQPgvo79WBvfkM/X7+Kh3smQHjS0Q2K6Xy7umskROlXala+EChFj7Qp5tJzl4GXP6AroEP6wCtMogD+kvAHZwbckDzVB24AR3EnTZ5cz42swRm9cESRBQTZi3cbAt3auENOeDIU4o1cmBoe7etWrH5JeDUzue0F2vnq+bq95w2+SNa9C8CsyZPZTGQE6WRUws3a8NOW3iuEbj8wWIwNedz/qJQqnYOABMvBmbcpoP4k1/V6Y7v/q4Y8IHwTTwpx6YVojQy0/7W36RrviMKTRojcsPn88s6sWXKtUOD+K4tOljPuE3///SFOm/9ze6h2S3sBC2LnZ1Eabb+Jt3+PLIRmOtRC48Lp1z5q/QNR4DuhGUn6D7s7CRKO/edmj0dOiAe0qzbxB21rIUHZX65nL1M19QH+tgJGhADOVFamNkpZkfiBW06MJqZK2GzUWzzym4Z82ndCbphpZ7GZpaS2NlJlBZmu/jhE/U0s8PQdkaKTX7ZLRtW6k7QN7v1XadsZhmGgZwoTZzb4TtuLtxw0zr0vSQEQHft/PSFOoh3PzB8nwgAm1aI0sW8Hb7cyIRx5c5u6ekojryY1H2KGGvkRGkxLCf8jHhmp1TCa5+OnwGcdMnwgJ/RXHPWyInSwt0kEcfslEp57ZOA45wbmEdOlDTuQbCAbNZGzbzzjOSa
M4+cKC341B3N7NjNeK45AzlR0rhvv096O3i1nI5dZ8CtDOeas7OTKIn80gyzwN0J6gy4ldFcc9bIiZLA6/Z7Z1zvrKXklco1dx4Unb8qU0EcYCAnSgav2+9bryv/AOQ08so1b56qv+DcX3opby9n0wpREiT59vuouR9CsWqu/pvTrv+fgQdUMJATJUUabr+PgruZZU67DuQdtwBvbc1EezmbVoiSIg2330fB3czS3KrHMe95JjNpiQzkREngHq87i+3iQWUwLZFNK0RJ4Hf7fcqbDSqS0bREBnKiJCh1632W28W9+A2B2zI7tcMasGmFiNLDKy0x5c0srJETxQkHxLInQ80sVmrkIvJ9EXlbRLbaWB9RZnFALHsydPenrRr5DwCsBPBDS+sb5pHndmH52h3YvbcXow7IQSngvd4+jB3ViCXTx+OiSeOGzDN2VCPOOr4JT7+8Z9//nfkoPdzn3MY59ruOzGvv4MYcRIC9H/hfhxWVqbkVnSffggn3zMW9/Z/HlSOfwrbTb8OU5taqPgOlppvldq/H3Fev+bz2O8jxdC/rtT2/Mpbap9LlOw5jR/VjyfRduGjSOHQ+uQYTup/Avf0zcWXHHdjWfwKmnHNxRddCJee4lvHI2njkInI0gJ8qpSaWm7fS8cgfeW4XblzzInr7Bkq+35hrwCWfGYeHNu/ynMeZ75sXn8RgnhKlrouw57jctebH6zqspEzO9heoVbhm5MO4rX8m7pA5Za9vv20H+WyE4bd/YY6n3/Yq2afGXAOu+9RbmLlzGa7uW4SNgxNw2ohtuD3Xhpc+1+YZzL2uryDnOMh+V3OtJno88uVrd/gekN6+Ady/6bWyJ7W3bwDL1+6wXTyqk1LXRdhzXO5a8+N1HVZSpuVrd+DkgW5c0bAOt/XPxBUN63DyQHfZ69tv20E+G2H47V+Y4+m3vUr2qbdvAH/csXFfEAeAjYMTcHXfIrzwf097Lud1fQU5x0H222Y8qlkgF5H5ItIlIl179uypaNnde3vLzjMQ8JdFkHVRMnidyzDnOOz14XUdeq7XNarhUe934Y7cCqwd+Ed8u/9SLOxbhJW5NpwSoPvJa9tBPxthRHEu/FS6T9/tn7EviDs2Dk7ALX8513MZr7IHOcdB99vW8alZIFdK3aWUyiul8k1NTRUtO3ZUY9l5GkSsrYuSwetchjnHYa8Pr+vQc72uzs3Z+z8LgeCxwdMA6GCzsG8RTh7xatXbDvrZCCOKc+Gn0n2q+Lz4vBdkXUH329bxSUTTypLp49GYa/B8vzHXgMtOPdJ3Hme+JdPH2y4e1Ump6yLsOS53rfnxug59y+R62s95I5/Fvw3+x5Da4/MNLfhz/uqqPgNBPxth+O1fmOPpt71K9qmq8wLv6yvIuoLst814ZCVrRUTuB3AmgNEi8jqA/1JK3W1j3QD2dQaU67HPf+JQZq1kiPu6sHGOS62z0qwV93VYtkzGqIYfaV2KmYdchp0lljfXG/QzUGp6LbNW3MfTZtaK1/EIdV6MPH7nvace/zHGfbAdjx00K/C6yl1Hsc1aqUSlWStEqZbBp8HHlvsmIvf/68wra4V3dhLVkztQNJ8Rq8CROWZTV4K+WBPRRk6UWn6jGlJ9mA/wSMh45gzkRPXkHuQJ0P/nuCr1k8DxzNm0QkTkSOhAW6yRExE5/AbainEzCwM5EZEjoeOZs2mFiKiUBDWzsEZORFRKgppZGMiJiEpJUDMLm1aIiMqJeTMLAzkRUTmlmlne7NbNLK1L695WzkBORFSO+wYtp5mldam+eaj5jLoGc7aRE0XF9dAIAHXvFCMLzGaWs5cVx2Zxn+saYiAniorroRFx6BQjC7zGx/lVW92+uNm0QhSVhI6kR2WUGgfHOadeQ+BGjIGcKErmSHox6BSjCJlf3IdPBHY/B8xpL57zng5dm49gQDQ2rRBFyRlJz+kUq2M7KtWA88Xd8www2FecHnGzGgM5UVRi2ClGETO/uEfkgFVzgfU3
Rf6wEAZyIpvMTBWnU8yZzodGpJv7i3tOOzDQp5vVDp9Yen5LHaEM5EQ2mZkqTluo+ZOaD41IL3c2CwA05IDmqbq9fNXcyDKY+PBlItv4MGUq9RDnVXP1e6cuqPq68Hr4MmvkRLaZmSr5qxjEs6hUrvmcdmDspEiuCwZyItuYqUKlnsUKAG9tjeS6YCAnsomZKlRKxNcFAzmRTV63bzNTJdsivi7Y2UlElBDs7CQiSikGciKihGMgJyJKOAZyIqKEYyAnIko4BnKiavAxbhQjDORE1eBj3ChG+IQgomrwMW4UI6yRE1WLg2NRTDCQE1WLg2NRTFgJ5CJyrojsEJGdInKDjXUSxRoHx6IYCR3IRaQBwO0AzgNwIoDLROTEsOslijUOjkUxYqNGfgqAnUqpV5VSHwJYBeBCC+slihcz5dAZb9pMOeRj3KhObATycQBeM/7/emHaECIyX0S6RKRrz549FjZLVGNMOaSYshHIpcS0YWPjKqXuUkrllVL5pqamyrbAmy8oDsyUw/U3DX0mI1Ed2QjkrwM40vj/xwHstrDeItaEKC6YckgxZCOQPwvgOBFpFpH9AMwB8KiF9RaxJkRxwZRDiqHQgVwp1Q9gIYC1ALYDeFAptS3seodhTYjqjSmHFFNW8siVUo8rpT6llPqkUuomG+schjUhqgezf8ZJOXSmM+WQYiIZd3aaNaH9DgSmLB7eZs6OT4qC2T/jpBaa/TNMOaQYSMagWe6bL1bP08HcqQk5QZ7INg6ORQmQjBq5c/MFUPxgda4APvwrOz4peuyfoZhLRiB34weLaon9MxRzyQzkzgereSqw6Y6hHyy2l5NNzFShBEheIDc/WK3X6Wmr5urpvFGIbOPgWJQAyQvk5geruRWY066nd9zC9nIKzz0chJORYv7KY6YKxUzyArnZ8Qno16cuAHqeYXs5hcfhICiBkhfI3dgRRTZxOAhKoGQHct4oRFFgVhQlTLIDudlePm6yzi13bhTiT2KqFn/lUcIkO5DzRiGyjemGlEDJDuRu/ElMYTHdkBIoXYGcP4mpGnwWJyVcegK5+yfx8TOKNwqZ87Dzk9yYckgJl55A7v5JfNIl+t+tD+l/+eEkL0w5pIRLxjC2Qbh/+jp3fa6eBxx4GIcfJX9m/0rrUl4nlCjpqZGXws5PCor9K5Rg6Q7k/HBSEEw5pIRLbyDnh5P88FmclCLpDeTMByY/fBYnpYgopWq+0Xw+r7q6umq+XaIhnF9tfBYnJYSIbFZK5d3T01sjJyqHneGUEgzklF3sDKeUYCCnbGJnOKVI9gK5+1FeAG/dzyJ2hlOKZC+Qc1yN7OLgWJRS2QvkHFcju/glTimVnrFWKsFxNbLJ/BJnyiGlSPZq5ACzFbKMKYeUQtkL5MxWyDZ+iVMKZS+Qm9kKTieXma3ADJb04pc4pVT2Arn5wGan88uZzs6vdGPKIaUUx1rheBvp1nmr/mI2z2lPhw7eTDWkhOFYK17Y+ZVuTDmkDAgVyEXkUhHZJiKDIjLsWyIR2PmVbrxvgDIgbB75VgAXA7jTQlmq8shzu7B87Q7s3tuLsaMasWT6eFw0aVygZY56vwu379eG62UxPvnUK2jI/RP+9Z7L8ZUPF+EPH8vjW5P3YsoBfxjyE9zc3qgDclAKeK+3D2NHNeKs45vw9Mt7Sr7nlMtc/uDGHESAvR/0DZnfnO61T+79NrdtLu+1Da/X1ZTJ65hUul6vefz2z+vYurfxFXUm5nfcjDtxCe6492/Y+8HPPMsa9DqydX1Wcw3b2nbahYkP7mXqeS7LsdJGLiK/AHCdUipQw7etNvJHntuFG9e8iN6+gX3TGnMN+ObFJ3kePHOZLzc8hm51DDYOTsBpI7ZhZa4N3+m/ACMxiG51DG7PteGlz7VhyjkXe24vqMZcAy75zDg8tHlXxcu79ylMOWxxygQgFmXxO7bOub1vYBquaFiHhX2LsHFwQtl1+l1HQQS5Pqu5hm1tO+3CxgdzmVLX
V63OpSmVbeTL1+4Y9sHt7RvA8rU7Ai1z58CMfR/ojYMTsLBvEb4y8lEcIH/Dylwbru5bhOu3jPLdXlC9fQO4f9NrVS3v3qcw5bDFKVNcyuJ1bJ0gvrBvEb7dfykW9i3CylwbThuxrew6/a6jIIJcn9Vcw7a2nXZh44O5TKnrq1bnMoiygVxE1onI1hJ/F1ayIRGZLyJdItK1Z8+e6kts2L23t6Lp5d7bODgB9w1MwzUjH8Z9A9OwcXDCkPn9lg1iIMSvH5vlsGX33t7YlMU8tl9ueGxfoG6RV7Gwb9G+6c4Xdou8WnadYfctyPVZzTVsa9tpZzM+eH12a3EugygbyJVS05RSE0v8/aSSDSml7lJK5ZVS+aampupLbBg7qrGi6eXeO23ENlzRsA639c/EFQ3rcNqIbUPm91s2iAaRqpe1WQ5bxo5qjE1ZzGPbrY7ZV+u+c2AGAGBlrg3d6hgA+gvbme4n7L4FuT6ruYZtbTvtbMYHr89uLc5lEIluWlkyfTwacw1DpjXmGrBk+viKlgFK/wS/PdeGb03eW3bZIBpzDbjs1COrWt69T2HKYYtTpriUxTy2Tq17Za4N/z5y9b7zWq5d3L1Ov+soiCDXZzXXsK1tp52t+OD12a3VuQwiVNaKiMwE8D8AmgD8TESeV0pNt1KyAJwOhEp6id3LOBkQn/77q7heFuPl/SdCPujDHz6Wx0uT23TWisey1WSt5D9xaOislVL7Xc+sFa9jUuusFfPYvrz/yfjx4Dm4ZuRDuBOX4OX9T4YEPB62sg2CXJ/VXMO2tp12NuKD1/VVy3MZBO/spHTiHbuUQqnMWiEqiYNjUcYwkFP6cHAsyphsPiGI0sccHMu5E9ccHKu5lU0rlFqskVM6cHAsyjDWyCkd+DxOyjDWyCk9OCQxZRQDOaUHhySmjGIgp3RgyiFlGAM5JVfnrcVA7aQcOtOZckgZwkBOyWVmqjgph2amipmKSJRiDOSVMGuAjp4OPZ1qj49xIwLAQF4Z5irHDzNViBjIK8IaYPwwU4WIgbxirAHGBzNViAAwkFeONcD6YqYK0TAM5JUwa4D7HQhMWTy8zZwdn9FipgrRMBxrpRLu4VFXz9PB3KkBOkGeosMxVYiGYY28Es5wqEAxoHSuAD78Kzs+a4n9FERDMJCHwYBSH+ynIBqCgTwMBpTaY6YK0TAM5NViQKkdZqoQ+WIgr5a743PXlqEdnwCzWGxhpgqRLwbyapkdn4AOKp0risGFt+/bwztqiXwx/dAWpsVFy+xYbl3K40pkYI3cJmaxRIcdy0SeGMhtYrCJBjuWiXwxkNvC2/ftcY/77u5IZqYK0RAM5LaYWSxOx6cTfNjxWRn3uO/ujmSAmSpEBlFK1Xyj+XxedXV11Xy7NeUEb3Z8VofHj2gYEdmslMq7p7NGHhV2fIbD40cUGAN5VNjxGQ6PH1FgDORRYJZF5cwOTuf4TVmsO455/Ih8MZBHwX37PrMsyjM7OJ0sFaeDk8ePyBc7Oyk+2MFJ5IudnRR/7OAkqgoDOcUHOziJqhIqkIvIchF5WUS6ReRhERllq2Cp5L5jEcj2HZ/s4CSyImyN/OcAJiqlWgC8AuDG8EVKMfcdi1m/45MdnERWWOvsFJGZAL6glJpbbt5Md3ayQ28oHg+iwGrR2fklAE9YXF86sUNvKB4PotDKBnIRWSciW0v8XWjMswxAP4B2n/XMF5EuEenas2ePndInETv0huLxIAqtbCBXSk1TSk0s8fcTABCRLwL4ZwBzlU87jVLqLqVUXimVb2pqsrcHScKhbtnBSRSBsFkr5wK4HsAFSqkP7BQpxTjULTs4iSIQqrNTRHYC+AiAdwqTfq2UWlBuuUx3dpqy2tGX1f0mCsmrszPUw5eVUseGWT7zsvpA4azuN1FEeGdnPTkdfc1TgU13DG0bTlN7uftGqJ4Ovb/NU9nBSWQBA3m9mB2frdfpaavm6ulpay8328V7OvR+Anq/2cFJFFqophUK
wT3U7Zx2HeA6bgHe2pqudmOnE3P1PODwiXranPbi/jkdnGnZX6IaY428XqZcOzRwNbcCpy4Aep5J540xTrt4zzN6P937zgcpE1WNgTwu0tZeznZxopphII+DNLaXs12cqGbYRh4HaWwvZ7s4Uc2wRh4HaW0vZ7s4UU0wkMdRUgeSYrs4UV2waSVuzPby5lag+Yyh/48zp1380h/o/5vt4kBy9oMoYRjI48bdXm4OJBX3AMh2caK6sPaEoEpw0KwU6by1OHKhY818oPsB3TR09rL6lY0oZWrxhCDKIvdzSDesBLofBFpms12cqEbYtELhmM0px07TQfycbwCnLxze3k9EkWCNPInc2SFAfe/+dNIMux8AWmbpIO5M54MiiCLHQJ5E7uaMet/9aaZL7lw39EuG+eJEkWMgTyKzOWP9TcCPZutHppnNFzZr6H6/AMzmk7OX8fZ7ojpgIE8q8yk7J8zQz72Mqobu9wvAL12SiGqC6YdJ5X7upfMQ48MnArufG5q/ve9BxxU0cbjTCp2Br8ZOSu74L0QJF8kzO6lOvO7+PHaa7nDMNer5Om8FRozUAd6529IvqJvB26mFT1kMDPbr/w/06XFT+JxNolhh00oSlWrOmLIY2P6YDrIjcrr2/PZLwJNf1e/t2qJzvFfPA/7UUxxe9r4vFF//qUe/v2FlIdgv1su//ZJeX0MueeO/EGUAa+RJ5K5N93ToWvflDxRr6O2zCumAs/V7Zo73mJbiOChTry++ntMOjD5OB++WWToDpWVWsZZ/+YPJG/+FKANYI08Ddw0d0LXn5qk6GDtNLi2zCp2ivyzO97f3iq97fqnfd4L3sdP08s1TdS3fwQ5NolhhZ2fauNvPN6wcWsN2gnrrUj1/x81DX7fMNubjXZpEccLOzqwwa+dOk8s539Adlke0FIL6bD1OOKCDuPO6ZXYxeA/26387V+immCSNwkiUMQzkaWO2n7uD+up5Oji/85viPPsfXHydaywGb2e5MS3F4O38EVGsMJCnmVdQ77xVd2wCwK/aiq93bdHNKAzeRInCNnIiooTgeORERCnFQE5ElHAM5ERECcdATkSUcAzkREQJV5esFRHZA+D3FSwyGsAfIypOGHEtFxDfssW1XEB8yxbXcgHxLVtcywWEK9snlFJN7ol1CeSVEpGuUik39RbXcgHxLVtcywXEt2xxLRcQ37LFtVxANGVj0woRUcIxkBMRJVxSAvld9S6Ah7iWC4hv2eJaLiC+ZYtruYD4li2u5QIiKFsi2siJiMhbUmrkRETkITaBXEQuFZFtIjIoIp49uiJyrojsEJGdInKDMf1QEfm5iPym8O8hlspVdr0iMl5Enjf+3heRawvvfU1EdhnvnW+jXEHLVpjvdyLyYmH7XZUuH0W5RORIEXlaRLYXzvs1xntWj5nXNWO8LyLSVni/W0QmB102rABlm1soU7eIbBCRTxvvlTyvNSrXmSLynnGO/jPosjUo2xKjXFtFZEBEDi28F8kxE5Hvi8jbIrLV4/1orzGlVCz+AJwAYDyAXwDIe8zTAOC3AI4BsB+AFwCcWHjvZgA3FF7fAOBblspV0XoLZXwTOt8TAL4G4LqIjlmgsgH4HYDRYffNZrkAjAEwufD6IACvGOfS2jHzu2aMec4H8AQAAfBZAJuCLluDsp0O4JDC6/Ocsvmd1xqV60wAP61m2ajL5pp/BoD1NThmrQAmA9jq8X6k11hsauRKqe1KqR1lZjsFwE6l1KtKqQ8BrAJwYeG9CwHcU3h9D4CLLBWt0vV+HsBvlVKV3PBUrbD7XLdjppR6Qym1pfD6zwC2Axhnafsmv2vGLO8PlfZrAKNEZEzAZSMtm1Jqg1Lq3cJ/fw3g4xa3X3W5Ilo2ivVfBuB+i9svSSnVAeBPPrNEeo3FJpAHNA7Aa8b/X0fxw3+4UuoNQAcJAIdZ2mal652D4RfOwsLPqe/bar6osGwKwJMisllE5lexfFTlAgCIyNEAJgHYZEy2
dcz8rply8wRZNoxK138VdK3O4XVea1Wu00TkBRF5QkQmVLhs1GWDiBwA4FwADxmTozpm5UR6jdX0CUEisg7AESXeWqaU+kmQVZSYFjrtxq9cFa5nPwAXALjRmPxdAF+HLufXAfw3gC/VuGyfU0rtFpHDAPxcRF4u1CCqZvGYfRT6g3atUur9wuRQx8y9iRLT3NeM1zyRXG8Btjt8RpGzoAP5FGOy9fNaQbm2QDcf/qXQh/EIgOMCLht12RwzAPxKKWXWlKM6ZuVEeo3VNJArpaaFXMXrAI40/v9xALsLr98SkTFKqTcKP1netlEuEalkvecB2KKUestY977XIvK/AH4atFy2yqaU2l34920ReRj651wH6nzMRCQHHcTblVJrjHWHOmYuftdMuXn2C7BsGEHKBhFpAfA9AOcppd5xpvuc18jLZXzpQin1uIh8R0RGB1k26rIZhv06jvCYlRPpNZa0ppVnARwnIs2F2u8cAI8W3nsUwBcLr78IIEgNP4hK1jusPa4QyBwzAZTs1Y6qbCJyoIgc5LwGcI5RhrodMxERAHcD2K6UWuF6z+Yx87tmzPL+SyGz4LMA3is0CQVZNoyy6xeRowCsAXClUuoVY7rfea1FuY4onEOIyCnQseSdIMtGXbZCmQ4GMBXGtRfxMSsn2mvMdu9ttX/QH9jXAfwdwFsA1hamjwXwuDHf+dAZDr+FbpJxpv8DgKcA/Kbw76GWylVyvSXKdQD0hXywa/l7AbwIoLtwgsZYPGZlywbdG/5C4W9bXI4ZdBOBKhyX5wt/50dxzEpdMwAWAFhQeC0Abi+8/yKMrCmv683iOSxXtu8BeNc4Rl3lzmuNyrWwsN0XoDthT4/LMSv8fx6AVa7lIjtm0BW4NwD0Qcexq2p5jfHOTiKihEta0woREbkwkBMRJRwDORFRwjGQExElHAM5EVHCMZATESUcAzkRUcIxkBMRJdz/AzeiYhewEzSRAAAAAElFTkSuQmCC\n",
      "text/plain": [
       "<Figure size 432x288 with 1 Axes>"
      ]
     },
     "metadata": {
      "needs_background": "light"
     },
     "output_type": "display_data"
    }
   ],
   "source": [
    "# pass the training inputs through the model to get the posterior mean of f\n",
    "train_mean_f = model(train_x).loc.data.cpu()\n",
    "# plot the training labels, mapped from 0/1 to -1/1\n",
    "plt.plot(train_x.squeeze(-1).cpu(), train_y.mul(2.).sub(1.).cpu(), 'o')\n",
    "# plot the Gaussian process posterior mean evaluated at the training inputs\n",
    "plt.plot(train_x.squeeze(-1).cpu(), train_mean_f, 'x')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As expected, the Gaussian Process posterior mean (orange crosses) gives confident, large-magnitude predictions in regions\n",
    "where the correct label is unambiguous (e.g. for x ~ 0.5), and unconfident predictions close to zero in regions where\n",
    "the correct label is ambiguous (e.g. for x ~ 0.0).\n",
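    "\n",
    "Under the logistic link, a latent value $f$ corresponds to the class probability $p(y = 1 \\mid f) = \\sigma(f)$, so posterior means far from zero give probabilities near 0 or 1, while means near zero give probabilities near 0.5.\n",
    "The minimal sketch below uses plain Python, illustrative helper names (`sigmoid`, `predictive_prob`) and made-up latent means and standard deviations (hypothetical values, not taken from the model above).\n",
    "It also estimates, by Monte Carlo, the predictive probability $\\mathbb{E}_{q(f)}[\\sigma(f)]$ obtained by averaging over the latent posterior, which is pulled toward 0.5 when $f$ is uncertain.\n",
    "\n",
    "```python\n",
    "import math\n",
    "import random\n",
    "\n",
    "def sigmoid(f):\n",
    "    # illustrative helper, logistic link: maps a latent value f to P(y = 1 | f)\n",
    "    return 1.0 / (1.0 + math.exp(-f))\n",
    "\n",
    "def predictive_prob(mean, std, num_samples=10000, seed=0):\n",
    "    # illustrative helper: Monte Carlo estimate of E[sigmoid(f)] for f ~ N(mean, std**2)\n",
    "    rng = random.Random(seed)\n",
    "    total = sum(sigmoid(rng.gauss(mean, std)) for _ in range(num_samples))\n",
    "    return total / num_samples\n",
    "\n",
    "# hypothetical latent posteriors at two inputs: one confident, one ambiguous\n",
    "for mean, std in [(3.0, 0.5), (0.0, 0.5)]:\n",
    "    print(mean, std, round(sigmoid(mean), 3), round(predictive_prob(mean, std), 3))\n",
    "```\n",
    "\n",
    "With GPyTorch's default `marginal` implementation, passing a latent distribution through a likelihood, as the evaluation cell below does with `likelihood(model(test_x))`, returns a predictive distribution averaged over $f$ rather than a plug-in $\\sigma(\\text{mean})$.\n",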
    "\n",
    "We compute the negative log likelihood (NLL) and classification accuracy on the held-out test data."
   ]
  },
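  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Both metrics are functions of the per-point predictive probabilities $p_i$, the predicted probability that $y_i = 1$: the NLL is the mean over test points of $-\\log p_i$ if $y_i = 1$ and $-\\log(1 - p_i)$ if $y_i = 0$, and the accuracy is the fraction of points where the thresholded prediction $p_i > 0.5$ matches the label.\n",
    "The sketch below uses plain Python with illustrative helper names and made-up probabilities and labels (hypothetical values, not outputs of the model above); it assumes 0/1 labels, as in `test_y`.\n",
    "\n",
    "```python\n",
    "import math\n",
    "\n",
    "def mean_bernoulli_nll(probs, labels):\n",
    "    # illustrative helper: mean of -log p(y_i), where probs[i] = P(y_i = 1) and labels are 0/1\n",
    "    total = 0.0\n",
    "    for p, y in zip(probs, labels):\n",
    "        total -= math.log(p) if y == 1 else math.log(1.0 - p)\n",
    "    return total / len(probs)\n",
    "\n",
    "def accuracy(probs, labels):\n",
    "    # illustrative helper: threshold probabilities at 0.5, as the evaluation cell does\n",
    "    hits = sum((p > 0.5) == (y == 1) for p, y in zip(probs, labels))\n",
    "    return hits / len(probs)\n",
    "\n",
    "# hypothetical predictive probabilities and 0/1 test labels\n",
    "probs = [0.9, 0.2, 0.6, 0.45]\n",
    "labels = [1, 0, 0, 1]\n",
    "print(mean_bernoulli_nll(probs, labels), accuracy(probs, labels))\n",
    "```"
   ]
  },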
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Test NLL: 0.3481\n",
      "Test Acc: 0.9000\n"
     ]
    }
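   ],
   "source": [
    "model.eval()\n",
    "likelihood.eval()\n",
    "with torch.no_grad():\n",
    "    # posterior over the latent function at the test inputs\n",
    "    f_dist = model(test_x)\n",
    "    # per-point negative log marginal likelihood of the test labels\n",
    "    nlls = -likelihood.log_marginal(test_y, f_dist)\n",
    "    # predict class 1 when the predictive probability exceeds 0.5\n",
    "    acc = (likelihood(f_dist).probs.gt(0.5) == test_y.bool()).float().mean()\n",
    "print('Test NLL: {:.4f}'.format(nlls.mean()))\n",
    "print('Test Acc: {:.4f}'.format(acc))"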
   ]
  }
 ],
 "metadata": {
  "anaconda-cloud": {},
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
